import argparse
import torch
import numpy as np
from tqdm import *
import torch.optim as optim
from utils import KpiReader
from models import StackedVAGT


class Trainer(object):
    """Run the optimisation loop for a VAGT model, with checkpointing and beta annealing.

    The wrapped model ``vagt`` must be callable on a batch and return seven
    values (posterior sample, posterior mean/log-variance, prior
    mean/log-variance, ``x_mu``, ``x_logsigma``), and must expose
    ``loss_LLH``, ``loss_KL``, ``beta``, ``anneal_rate`` and ``max_beta``.
    """

    def __init__(self, vagt, train, trainloader, log_path='log_trainer', log_file='loss', epochs=20,
                 batch_size=1024, learning_rate=0.001, checkpoints='kpi_model.path', checkpoints_interval=1,
                 device=torch.device('cuda:0')):
        """Store the configuration, move the model to ``device`` and build an Adam optimizer.

        Args:
            vagt: Model to train (moved to ``device`` here).
            train: Training dataset object; stored as ``self.train`` only.
            trainloader: Iterable yielding 3-tuples whose last item is the data batch.
            log_path: Stored as-is; not used by this class.
            log_file: Stored as-is; not used by this class.
            epochs: Exclusive upper bound of the epoch range to run.
            batch_size: Stored as-is; the loop uses each batch's actual size.
            learning_rate: Learning rate passed to Adam.
            checkpoints: Path prefix for checkpoint files.
            checkpoints_interval: Save every N epochs; ``None`` or ``<= 0`` disables saving.
            device: Device the model and batches are moved to.
        """
        self.trainloader = trainloader
        self.train = train
        self.log_path = log_path
        self.log_file = log_file
        self.start_epoch = 0
        self.epochs = epochs
        self.device = device
        self.batch_size = batch_size
        self.vagt = vagt
        self.vagt.to(device)
        self.learning_rate = learning_rate
        self.checkpoints = checkpoints
        self.checkpoints_interval = checkpoints_interval
        # Report the number of parameter elements; formatting ``parameters()``
        # directly would only print a generator object.
        n_params = sum(p.numel() for p in self.vagt.parameters())
        print('Model parameters: {}'.format(n_params))
        self.optimizer = optim.Adam(self.vagt.parameters(), self.learning_rate)
        self.epoch_losses = []
        self.loss = {}

    def save_checkpoint(self, epoch, checkpoints):
        """Write model, optimizer, beta and loss history to ``<checkpoints>_epochs<epoch+1>.pth``.

        Args:
            epoch: Zero-based index of the epoch that just finished.
            checkpoints: Path prefix of the checkpoint file.
        """
        torch.save({'epoch': epoch + 1,
                    'beta': self.vagt.beta,
                    'state_dict': self.vagt.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'losses': self.epoch_losses},
                   checkpoints + '_epochs{}.pth'.format(epoch + 1))

    def load_checkpoint(self, start_ep, checkpoints):
        """Restore training state from ``<checkpoints>_epochs<start_ep>.pth`` if it can be loaded.

        On success ``self.start_epoch`` is set to the epoch stored in the
        checkpoint so that ``train_model`` continues from there. On any
        failure a message is printed and training starts from epoch 0.

        Args:
            start_ep: Epoch number embedded in the checkpoint file name.
            checkpoints: Path prefix of the checkpoint file.
        """
        path = checkpoints + '_epochs{}.pth'.format(start_ep)
        print("Loading Chechpoint from ' {} '".format(path))
        try:
            # map_location lets a checkpoint saved on another device be loaded here.
            checkpoint = torch.load(path, map_location=self.device)
            self.vagt.beta = checkpoint['beta']
            self.vagt.load_state_dict(checkpoint['state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.epoch_losses = checkpoint['losses']
            self.start_epoch = checkpoint['epoch']
        except FileNotFoundError:
            print("No Checkpoint Exists At '{}', Starting Fresh Training".format(checkpoints))
            self.start_epoch = 0
        except Exception as err:
            # Loading stays best-effort, but say why it failed instead of
            # pretending the file is missing.
            print("Failed To Load Checkpoint '{}' ({}: {}), Starting Fresh Training".format(
                path, type(err).__name__, err))
            self.start_epoch = 0
        else:
            print("Resuming Training From Epoch {}".format(self.start_epoch))

    def train_model(self):
        """Train from ``self.start_epoch`` up to ``self.epochs``.

        Each batch minimises ``-llh + beta * kld_z`` where both terms are
        divided by the batch size. After each epoch the mean metrics are
        printed and stored in ``self.loss``, a checkpoint is optionally saved,
        and ``beta`` is annealed.
        """
        self.vagt.train()
        for epoch in range(self.start_epoch, self.epochs):
            losses = []
            llhs = []
            kld_zs = []
            print("Running Epoch : {}".format(epoch + 1))
            # Wrapping the loader itself lets tqdm pick up its length when available.
            for _, _, data in tqdm(self.trainloader):
                batch_size = data.size(0)
                data = data.to(self.device)
                self.optimizer.zero_grad()

                (_, z_mean_posterior, z_logvar_posterior, z_mean_prior,
                 z_logvar_prior, x_mu, x_logsigma) = self.vagt(data)

                llh = self.vagt.loss_LLH(data, x_mu, x_logsigma) / batch_size
                kld_z = self.vagt.loss_KL(z_mean_posterior, z_logvar_posterior,
                                          z_mean_prior, z_logvar_prior) / batch_size

                loss = -llh + self.vagt.beta * kld_z
                loss.backward()
                self.optimizer.step()
                losses.append(loss.item())
                llhs.append(llh.item())
                kld_zs.append(kld_z.item())

            # Plain Python floats keep the checkpoint payload free of NumPy scalars.
            meanloss = float(np.mean(losses))
            meanllh = float(np.mean(llhs))
            meanz = float(np.mean(kld_zs))
            self.epoch_losses.append(meanloss)
            print("Epoch {} : Average Loss: {} Loglikelihood: {} KL of z: {}, Beta: {}".format(
                epoch + 1, meanloss, meanllh, meanz, self.vagt.beta))
            self.loss['Epoch'] = epoch + 1
            self.loss['Avg_loss'] = meanloss
            self.loss['Llh'] = meanllh
            self.loss['KL_z'] = meanz

            interval = self.checkpoints_interval
            if interval is not None and interval > 0 and (epoch + 1) % interval == 0:
                self.save_checkpoint(epoch, self.checkpoints)

            # Anneal beta after every epoch, capped at max_beta.
            annealed = (self.vagt.beta + 0.01) * np.exp(self.vagt.anneal_rate * (epoch + 1))
            self.vagt.beta = float(np.minimum(annealed, self.vagt.max_beta))

        print("Training is complete!")


def main():
    """Parse command-line options, build the dataset, model and trainer, then train.

    Raises:
        ValueError: If ``--dataset_path`` does not exist.
    """
    import os
    os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

    parser = argparse.ArgumentParser()

    parser.add_argument('--gpu_id', type=int, default=None)

    parser.add_argument('--dataset_path', type=str, default='')
    parser.add_argument('--batch_size', type=int, default=None)
    parser.add_argument('--num_workers', type=int, default=None)
    parser.add_argument('--x_dim', type=int, default=None)
    parser.add_argument('--win_len', type=int, default=None)

    parser.add_argument('--c_dim', type=int, default=None)
    parser.add_argument('--z_dim', type=int, default=None)
    parser.add_argument('--h_dim', type=int, default=None)
    parser.add_argument('--n_head', type=int, default=None)
    parser.add_argument('--layer_xz', type=int, default=None)
    parser.add_argument('--layer_h', type=int, default=None)
    parser.add_argument('--q_len', type=int, default=None)
    parser.add_argument('--embd_h', type=int, default=None)
    parser.add_argument('--embd_s', type=int, default=None)
    parser.add_argument('--vocab_len', type=int, default=None)

    parser.add_argument('--dropout', type=float, default=None)
    parser.add_argument('--learning_rate', type=float, default=None)
    parser.add_argument('--beta', type=float, default=None)
    parser.add_argument('--max_beta', type=float, default=None)
    parser.add_argument('--anneal_rate', type=float, default=None)
    parser.add_argument('--epochs', type=int, default=None)
    parser.add_argument('--start_epoch', type=int, default=None)
    parser.add_argument('--checkpoints_interval', type=int, default=None)
    parser.add_argument('--checkpoints_path', type=str, default='')
    parser.add_argument('--checkpoints_file', type=str, default='')
    parser.add_argument('--log_path', type=str, default='')
    parser.add_argument('--log_file', type=str, default='')

    args = parser.parse_args()

    # --gpu_id defaults to None; comparing None with an int would raise, so
    # an omitted or negative id selects the CPU.
    use_cuda = args.gpu_id is not None and args.gpu_id >= 0 and torch.cuda.is_available()
    device = torch.device('cuda:%d' % args.gpu_id) if use_cuda else torch.device('cpu')

    if not os.path.exists(args.dataset_path):
        raise ValueError('Unknown dataset path: {}'.format(args.dataset_path))

    # An empty path means "current directory"; os.makedirs('') would fail.
    for directory in (args.log_path, args.checkpoints_path):
        if directory and not os.path.exists(directory):
            os.makedirs(directory, exist_ok=True)

    # Checkpoint and log files share one default name derived from the
    # model hyper-parameters.
    default_name = ('c_dim-{}_x_dim-{}_z_dim-{}_h_dim-{}_layer_xz-{}_layer_h-{}_embd_h-{}_n_head-{}_'
                    'win_len-{}_q_len-{}_vocab_len-{}').format(
        args.c_dim, args.x_dim, args.z_dim, args.h_dim, args.layer_xz, args.layer_h,
        args.embd_h, args.n_head, args.win_len, args.q_len, args.vocab_len)
    if args.checkpoints_file == '':
        args.checkpoints_file = default_name
    if args.log_file == '':
        args.log_file = default_name

    kpi_value_train = KpiReader(args.dataset_path)
    # DataLoader needs an integer worker count; fall back to 0 (load in the
    # main process) when --num_workers is omitted.
    num_workers = args.num_workers if args.num_workers is not None else 0
    train_loader = torch.utils.data.DataLoader(kpi_value_train, batch_size=args.batch_size,
                                               shuffle=True, num_workers=num_workers)

    stackedvagt = StackedVAGT(layer_xz=args.layer_xz, layer_h=args.layer_h, n_head=args.n_head,
                              c_dim=args.c_dim, x_dim=args.x_dim, z_dim=args.z_dim, h_dim=args.h_dim,
                              embd_h=args.embd_h, embd_s=args.embd_s, beta=args.beta, q_len=args.q_len,
                              vocab_len=args.vocab_len, win_len=args.win_len, dropout=args.dropout,
                              anneal_rate=args.anneal_rate, max_beta=args.max_beta,
                              device=device).to(device)
    # Dump every named parameter with its shape.
    for name, parameters in stackedvagt.named_parameters():
        print(name, ':', parameters, parameters.size())

    trainer = Trainer(stackedvagt, kpi_value_train, train_loader, log_path=args.log_path, epochs=args.epochs,
                      log_file=args.log_file, batch_size=args.batch_size, learning_rate=args.learning_rate,
                      checkpoints=os.path.join(args.checkpoints_path, args.checkpoints_file),
                      checkpoints_interval=args.checkpoints_interval, device=device)
    trainer.load_checkpoint(args.start_epoch, trainer.checkpoints)
    trainer.train_model()


if __name__ == '__main__':
    # Script entry point: install a blanket "ignore" warning filter, then
    # hand control to main().
    import warnings

    warnings.simplefilter("ignore")
    main()
